#!/usr/bin/python

from pylab import *
import argparse,glob,os,os.path
from scipy.ndimage import filters,interpolation,morphology
from scipy import stats
import multiprocessing



# Command line interface.  Options are registered in the order in which
# they should appear in the --help output.
parser = argparse.ArgumentParser()

# gray-level normalization and thresholding
parser.add_argument(
    '-t', '--threshold', type=float, default=0.5,
    help='threshold, determines lightness')
parser.add_argument(
    '-z', '--zoom', type=float, default=0.5,
    help='zoom for page background estimation, smaller=faster')
parser.add_argument(
    '-e', '--escale', type=float, default=1.0,
    help='scale for estimating a mask over the text region')
parser.add_argument(
    '-b', '--bignore', type=float, default=0.1,
    help='ignore this much of the border for threshold estimation')
parser.add_argument(
    '-p', '--perc', type=float, default=80,
    help='percentage for filters')
parser.add_argument(
    '-r', '--range', type=float, default=20,
    help='range for filters')
parser.add_argument(
    '-m', '--maxskew', type=float, default=2,
    help='skew angle estimation parameters (degrees)')
parser.add_argument(
    '-g', '--gray', action='store_true',
    help='force grayscale processing even if image seems binary')
parser.add_argument(
    '--lo', type=float, default=5,
    help='percentile for black estimation')
parser.add_argument(
    '--hi', type=float, default=90,
    help='percentile for white estimation')
parser.add_argument(
    '--skewsteps', type=int, default=8,
    help='steps for skew angle estimation (per degree)')

# display; --debug doubles as the timeout handed to ginput()
parser.add_argument(
    '--debug', type=float, default=0,
    help='display intermediate results')
parser.add_argument(
    '--show', action='store_true',
    help='display final result')

# inputs, outputs and number of worker processes
parser.add_argument(
    '-o', '--output', default=None,
    help="output directory")
parser.add_argument('files', nargs='+')
parser.add_argument('-Q', '--parallel', type=int, default=0)

args = parser.parse_args()

# Expand every command line pattern with glob and collect the matches in
# args.files, replacing the raw patterns.
files = []
for pattern in args.files:
    # glob.glob() returns matches in arbitrary order; sorting keeps the page
    # numbers that are derived from the position in this list reproducible.
    expanded = sorted(glob.glob(pattern))
    if not expanded:
        # report the offending pattern instead of failing on a bare assert
        # (which would also be stripped when running under -O)
        parser.error("%s: expansion did not yield any files" % pattern)
    files += expanded
args.files = files



def estimate_scale(binary):
    """Return a typical object size for the given binary image.

    Every object is assigned the square root of its bounding box area
    (see A); the result is the median of those values over all marked
    pixels, ignoring values outside the open interval (3, 100).
    """
    # NOTE(review): binary_objects is neither defined nor imported in this
    # file; presumably it returns one tuple of slices per object (that is
    # what A() and the indexing below require) -- confirm where it lives.
    boxes = binary_objects(binary)
    sizemap = zeros(binary.shape)
    # go from the smallest to the largest box; a box that touches pixels
    # which already carry a size is left out
    for box in sorted(boxes,key=A):
        if amax(sizemap[box])>0:
            continue
        sizemap[box] = A(box)**0.5
    plausible = (sizemap>3)&(sizemap<100)
    return median(sizemap[plausible])

def estimate_skew_angle(image,angles):
    """Return the candidate angle whose rotation gives the sharpest row profile.

    For each angle in angles the image is rotated (nearest-neighbour,
    constant border) and the variance of its row means is taken as the
    score.  The angle with the highest score wins; equal scores are
    resolved in favour of the larger angle because (score, angle) tuples
    are compared.  With args.debug>0 the scores are plotted against the
    angles and the plot stays up for at most args.debug seconds.
    """
    scored = []
    for angle in angles:
        rotated = interpolation.rotate(image,angle,order=0,mode='constant')
        profile = mean(rotated,axis=1)
        scored.append((var(profile),angle))
    if args.debug>0:
        plot([ang for _,ang in scored],[score for score,_ in scored])
        ginput(1,args.debug)
    return max(scored)[1]
    
def select_regions(binary,f,min=0,nbest=100000):
    """Return a mask of the best-scoring connected components of binary.

    binary -- array-like; its nonzero pixels are grouped into connected
              components
    f      -- scoring function, called with the bounding box (a tuple of
              slices) of each component
    min    -- components whose score is less than or equal to this value
              are dropped
    nbest  -- at most this many of the highest-scoring components are
              kept; zero or a negative value keeps none

    The result is a uint8 array with the shape of binary that is 1 on the
    pixels of every kept component and 0 everywhere else.
    """
    # Imported here because the file never imports scipy.ndimage's labelling
    # functions; the previous reference to an undefined `measurements` name
    # raised a NameError on every call.
    import numpy as np
    from scipy import ndimage
    labels,_ = ndimage.label(binary)
    objects = ndimage.find_objects(labels)
    scores = [f(o) for o in objects]
    # keep[k] says whether label k survives; label 0 is the background
    keep = np.zeros(len(objects)+1,'B')
    # guard nbest<=0 explicitly: a slice starting at -0 would select everything
    if nbest>0:
        for i in np.argsort(scores)[-nbest:]:
            if scores[i]<=min: continue
            keep[i+1] = 1
    return keep[labels]

def H(s):
    """Height of a bounding box: the length of its first-axis slice."""
    first = s[0]
    return first.stop-first.start

def W(s):
    """Width of a bounding box: the length of its second-axis slice."""
    second = s[1]
    return second.stop-second.start

def A(s):
    """Area of a bounding box, i.e. width times height."""
    width,height = W(s),H(s)
    return width*height

def dshow(image,info):
    """Display an intermediate image with a title while debugging.

    Does nothing unless args.debug>0.  Otherwise the image is shown in
    gray scale with info as its title and the call waits for one mouse
    click, for at most args.debug seconds.
    """
    if args.debug>0:
        ion()
        gray()
        imshow(image)
        title(info)
        ginput(1,args.debug)
def process1(job):
    """Normalize, deskew and binarize one page image and save the results.

    job is a (fname, i) tuple: fname is the path of the input image and i
    is the page number used to name the outputs when args.output is set.

    The image is scaled to [0,1]; unless it already looks binary, a
    percentile-filter estimate of the local white level is subtracted;
    if args.maxskew>0 the skew angle is estimated and undone; the gray
    levels are then stretched between the args.lo and args.hi percentiles
    and thresholded at args.threshold.

    Both the gray-level and the thresholded image are written, either to
    args.output (as NNNN.png and NNNN.bin.png) or next to the input (as
    <base>.nrm.png and <base>.bin.png).  A one-line summary is printed.
    Returns None.
    """
    fname,i = job
    comment = ""
    image = imread(fname)
    if image.ndim==3: image = mean(image,2)
    # imread can hand back integer arrays (non-PNG formats); all of the
    # arithmetic below needs floating point, otherwise the normalization
    # truncates or fails outright
    if image.dtype.kind!='f': image = image.astype(float)
    dshow(image,"input")
    # perform image normalization (not in place: the array returned by
    # imread is not guaranteed to be writable)
    image = image-amin(image)
    peak = amax(image)
    # a constant image has peak 0; leave it at all zeros instead of
    # dividing by zero
    if peak>0: image = image/peak

    # check whether the image is already effectively binarized
    if args.gray:
        extreme = 0
    else:
        extreme = (sum(image<0.05)+sum(image>0.95))*1.0/prod(image.shape)
    if extreme>0.95:
        comment += " no-normalization"
        flat = image
    else:
        # if not, we need to flatten it by estimating the local whitelevel
        # args.range is parsed as a float, but filter sizes must be integers
        r = max(1,int(args.range))
        m = interpolation.zoom(image,args.zoom)
        m = filters.percentile_filter(m,args.perc,size=(r,2))
        m = filters.percentile_filter(m,args.perc,size=(2,r))
        m = interpolation.zoom(m,1.0/args.zoom)
        if args.debug>0: clf(); imshow(m,vmin=0,vmax=1); ginput(1,args.debug)
        # zooming down and up again may change the shape by a pixel, so
        # only the common region is used
        w,h = minimum(array(image.shape),array(m.shape))
        flat = clip(image[:w,:h]-m[:w,:h]+1,0,1)
        if args.debug>0: clf(); imshow(flat,vmin=0,vmax=1); ginput(1,args.debug)

    # estimate skew angle and rotate
    if args.maxskew>0:
        d0,d1 = flat.shape
        o0,o1 = int(args.bignore*d0),int(args.bignore*d1)
        # work on the inverted image so that the constant (zero) fill used
        # by the rotation corresponds to white paper
        flat = amax(flat)-flat
        flat -= amin(flat)
        est = flat[o0:d0-o0,o1:d1-o1]
        ma = args.maxskew
        ms = int(2*args.maxskew*args.skewsteps)
        angle = estimate_skew_angle(est,linspace(-ma,ma,ms+1))
        flat = interpolation.rotate(flat,angle,mode='constant',reshape=0)
        flat = amax(flat)-flat
    else:
        angle = 0

    # estimate low and high thresholds
    d0,d1 = flat.shape
    o0,o1 = int(args.bignore*d0),int(args.bignore*d1)
    est = flat[o0:d0-o0,o1:d1-o1]
    if args.escale>0:
        # by default, we use only regions that contain
        # significant variance; this makes the percentile
        # based low and high estimates more reliable
        e = args.escale
        v = est-filters.gaussian_filter(est,e*20.0)
        v = filters.gaussian_filter(v**2,e*20.0)**0.5
        v = (v>0.3*amax(v))
        # e is a float, but array shapes must be (positive) integers
        k = max(1,int(e*50))
        v = morphology.binary_dilation(v,structure=ones((k,1)))
        v = morphology.binary_dilation(v,structure=ones((1,k)))
        if args.debug>0: imshow(v); ginput(1,args.debug)
        est = est[v]
    lo = stats.scoreatpercentile(est.ravel(),args.lo)
    hi = stats.scoreatpercentile(est.ravel(),args.hi)

    # rescale the image to get the gray scale image
    flat = flat-lo
    # identical percentiles would mean a division by zero; skip the stretch
    if hi!=lo: flat = flat/(hi-lo)
    flat = clip(flat,0,1)
    if args.debug>0: imshow(flat,vmin=0,vmax=1); ginput(1,args.debug)
    binary = 1*(flat>args.threshold)

    # output the normalized grayscale and the thresholded images
    # (a single formatted argument prints the same line under Python 2 and 3)
    print("%s lo-hi (%.2f %.2f) angle %4.1f %s"%(fname,lo,hi,angle,comment))
    if args.debug>0 or args.show: clf(); gray();imshow(binary); ginput(1,max(0.1,args.debug))
    # imsave uses the current colormap, so make sure it is the gray one
    gray()
    if args.output:
        if not os.path.exists(args.output):
            os.mkdir(args.output)
        imsave(args.output+"/%04d.png"%i,flat)
        imsave(args.output+"/%04d.bin.png"%i,binary)
    else:
        base,_ = os.path.splitext(fname)
        imsave(base+".nrm.png",flat)
        imsave(base+".bin.png",binary)

import ocrolib

# Displaying results forces sequential processing
# (presumably because the interactive windows need the main process).
if args.debug>0 or args.show: args.parallel = 0

# create the output directory once, before any job runs
if args.output:
    if not os.path.exists(args.output):
        os.mkdir(args.output)

# one (filename, page number) job per input file; page numbers start at 1
jobs = [(f,i+1) for i,f in enumerate(args.files)]

if args.parallel<2:
    for job in jobs:
        process1(job)
else:
    pool = multiprocessing.Pool(processes=args.parallel)
    try:
        pool.map(process1,jobs)
    finally:
        # release the worker processes even if a job failed
        pool.close()
        pool.join()